{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6459ddb8",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c5bf94b-4f1d-4e94-a6e6-94644095e910",
   "metadata": {},
   "source": [
    "# MONAI Dice Overview\n",
    "\n",
    "This notebook summarises some of the details relating to the `DiceLoss` and `DiceMetric` classes and their behaviour."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "952771b9",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7eb650",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55f81921",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "16dc1913-abd6-48f4-a698-1f728aad742c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from monai.config import print_config\n",
    "from monai.losses import DiceLoss\n",
    "from monai.metrics import DiceMetric\n",
    "from monai.transforms import AsDiscrete, Compose\n",
    "\n",
    "print_config()\n",
    "\n",
    "\n",
    "def print_tensor(name, t):\n",
    "    print(f\"{name}: {t.numpy().tolist()} shape: {tuple(t.shape)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e155189f-f3a4-4353-9819-a6deec8ee331",
   "metadata": {},
   "source": [
    "## Binary Formulation\n",
    "\n",
    "Start with the binary case: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "754a92d2-ff9b-4f3d-8605-5c8a5a0e9256",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CHW\n",
    "grnd = torch.zeros(1, 4, 4)\n",
    "pred = torch.zeros(1, 4, 4)\n",
    "\n",
    "# grnd= [0,0,0,0] pred= [0,0,0,0]\n",
    "#       [0,1,1,0]       [0,1,1,0]\n",
    "#       [0,1,1,0]       [1,1,0,0]\n",
    "#       [0,0,0,0]       [0,0,0,0]\n",
    "grnd[..., 1, 1] = grnd[..., 1, 2] = grnd[..., 2, 1] = grnd[..., 2, 2] = 1\n",
    "pred[..., 1, 1] = pred[..., 1, 2] = pred[..., 2, 0] = pred[..., 2, 1] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5e5923-e4e4-444d-b227-f1abd92e6e7c",
   "metadata": {},
   "source": [
    "Given this ground truth and prediction we would expect a dice loss of 0.25 and so a dice score of 0.75. The prediction correctly labels 3 pixels as foreground out of 4 with 1 pixel of oversegmentation and 1 of undersegmentation. \n",
    "\n",
    "By default the loss and metric will work on the binary case and produce the expected results, that is calculate the dice loss/metric for the foreground:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8771ec62-b468-4235-a946-fba899dafe91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.24999964237213135 shape: ()\n",
      "metric: [[0.75]] shape: (1, 1)\n"
     ]
    }
   ],
   "source": [
    "loss = DiceLoss()\n",
    "metric = DiceMetric()\n",
    "\n",
    "# using [None] to add batch dimension\n",
    "print_tensor(\"loss\", loss(pred[None], grnd[None]))\n",
    "print_tensor(\"metric\", metric(pred[None], grnd[None]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a45677cb-108a-4c17-84f9-cb8d3b1fa3ea",
   "metadata": {},
   "source": [
    "The loss assumes by default that the input has been activated already. If the argument `sigmoid` is set to True sigmoid activation is applied to the prediction, similarly if `softmax` is True softmax is applied. Our fake predictions used here are treated as if they have been through activation already. The metric also assumes that the prediction given to it has been activated.\n",
    "\n",
    "This binary data can be treated instead as a two class segmentation with background and foreground as the classes. This is less efficient than using the loss and metric in the binary form as shown above but does produce the correct results. If the `include_background` arguments for both are True (the default) and `reduction` is `None` we can see the loss/scores for both classes in the outputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed3930b1-959e-426d-b5d5-4d5e7727fca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make one hot and add batch dimension\n",
    "make_2_class = Compose([AsDiscrete(to_onehot=2), lambda x: x[None]])\n",
    "\n",
    "grnd2 = make_2_class(grnd)\n",
    "pred2 = make_2_class(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cde8ff6f-0b29-43bf-8578-013c02d6b519",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss2: [[[[0.08333331346511841]], [[0.24999964237213135]]]] shape: (1, 2, 1, 1)\n",
      "metric2: [[0.9166666865348816, 0.75]] shape: (1, 2)\n"
     ]
    }
   ],
   "source": [
    "loss2 = DiceLoss(include_background=True, reduction=\"none\")\n",
    "metric2 = DiceMetric(include_background=True, reduction=\"none\")\n",
    "\n",
    "print_tensor(\"loss2\", loss2(pred2, grnd2))\n",
    "print_tensor(\"metric2\", metric2(pred2, grnd2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be190e4d-dcb1-4825-bda3-d13429c3d455",
   "metadata": {},
   "source": [
    "The second values correspond to the foreground loss/score and match up with what was seen above.\n",
    "\n",
    "With the loss the ground truth can be provided as a flat segmentation map if `to_onehot_y` is True:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12b62378-5c4b-462e-844f-a6c3212b1dc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss_onehot: [[[[0.08333331346511841]], [[0.24999964237213135]]]] shape: (1, 2, 1, 1)\n"
     ]
    }
   ],
   "source": [
    "loss_onehot = DiceLoss(include_background=True, to_onehot_y=True, reduction=\"none\")\n",
    "\n",
    "# note grnd[None] instead of grnd2\n",
    "print_tensor(\"loss_onehot\", loss_onehot(pred2, grnd[None]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c736e6a-b6fa-4863-a9d1-5d688ed4f5f5",
   "metadata": {},
   "source": [
    "If `include_background` is False then class 0 is discarded from the loss/score calculation. This is useful for loss functions when the background class may dominate the calculation and lead the network to optimise by just ignoring small segmentation classes, but probably this only makes sense in true multi-class situations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11748619-02c0-48df-8543-9eca5c1d3d48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss3: [[[[0.24999964237213135]]]] shape: (1, 1, 1, 1)\n",
      "metric3: [[0.75]] shape: (1, 1)\n"
     ]
    }
   ],
   "source": [
    "loss3 = DiceLoss(include_background=False, reduction=\"none\")\n",
    "metric3 = DiceMetric(include_background=False, reduction=\"none\")\n",
    "\n",
    "print_tensor(\"loss3\", loss3(pred2, grnd2))\n",
    "print_tensor(\"metric3\", metric3(pred2, grnd2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80544ab5-bf2d-4d0d-8cc9-1906b17a620d",
   "metadata": {},
   "source": [
    "The results here are the same values as with the binary for of the loss and metric just different shape. If `include_background` is True and `reduction` is `mean` the loss/score looks different for the same data because the values for the background are included and averaged with that of the foreground. A pixel/voxel can only be fore- or background (ie. the relationship between foreground and background is exclusive) so the loss/metric values for each category improve or worsen in lockstep and provide no added signal. It's therefore unnecessary to calculate this and leads to poor training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ea7806fa-ddee-49d0-a16a-aacbc9139533",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss4: 0.16666647791862488 shape: ()\n",
      "metric4: [[0.9166666865348816, 0.75]] shape: (1, 2)\n"
     ]
    }
   ],
   "source": [
    "loss4 = DiceLoss(include_background=True, reduction=\"mean\")\n",
    "metric4 = DiceMetric(include_background=True, reduction=\"mean\")\n",
    "\n",
    "print_tensor(\"loss4\", loss4(pred2, grnd2))\n",
    "print_tensor(\"metric4\", metric4(pred2, grnd2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27789c2f-8c76-47ec-bb8d-bc8208b8bc77",
   "metadata": {},
   "source": [
    "## Multi-class Formulation\n",
    "\n",
    "Onto the multi-class case, in this example 3 classes where class 1 is perfectly segmented:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fbad0456-cb89-4f4d-93bb-a39a5fde05c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mgrnd = torch.zeros(1, 4, 4)\n",
    "mpred = torch.zeros(1, 4, 4)\n",
    "\n",
    "# grnd= [0,0,0,0] pred= [0,0,0,0]\n",
    "#       [0,1,1,0]       [0,1,1,0]\n",
    "#       [0,2,2,0]       [2,2,0,0]\n",
    "#       [0,0,0,0]       [0,0,0,0]\n",
    "mgrnd[..., 1, 1] = mgrnd[..., 1, 2] = 1\n",
    "mgrnd[..., 2, 1] = mgrnd[..., 2, 2] = 2\n",
    "\n",
    "mpred[..., 1, 1] = mpred[..., 1, 2] = 1\n",
    "mpred[..., 2, 0] = mpred[..., 2, 1] = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21ac58d8-233f-43f5-aace-86ef81da9236",
   "metadata": {},
   "source": [
    "Using the binary formulation with these values will lead to weird results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "40c3393e-1b34-4cd3-8a67-e3c91e6f266f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.0 shape: ()\n",
      "metric: [[1.0]] shape: (1, 1)\n"
     ]
    }
   ],
   "source": [
    "print_tensor(\"loss\", loss(mpred[None], mgrnd[None]))\n",
    "print_tensor(\"metric\", metric(mpred[None], mgrnd[None]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "122736d4-f0fd-48ff-a644-ee192cae59aa",
   "metadata": {},
   "source": [
    "Instead the prediction and ground truth should be in one-hot format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "53412728-f7be-4d35-a8cf-380b1c97bb1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make one hot and add batch dimension\n",
    "make_3_class = Compose([AsDiscrete(to_onehot=3), lambda x: x[None]])\n",
    "\n",
    "mgrnd2 = make_3_class(mgrnd)\n",
    "mpred2 = make_3_class(mpred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2305a434-ab0e-4176-b6a6-313fb0b92ca3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.19444401562213898 shape: ()\n",
      "metric: [[0.9166666865348816, 1.0, 0.5]] shape: (1, 3)\n"
     ]
    }
   ],
   "source": [
    "print_tensor(\"loss\", loss(mpred2, mgrnd2))\n",
    "print_tensor(\"metric\", metric(mpred2, mgrnd2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "498c30d7-1a1c-47dc-a573-9b924e5a31d1",
   "metadata": {},
   "source": [
    "Once again note that the assumption is the inputs for both the loss and metric have been through activation already and are in one-hot format. If the prediction hasn't been activated `softmax` can be set to True to apply this activation, and `to_onehot_y` can be set to True to convert the ground truth to one-hot. \n",
    "\n",
    "The design decision in MONAI is not to integrate activation into the final layer of networks and to use post-processing transforms to apply activation to predictions or use loss function arguments. Be aware of what your network is producing as its final output and whether this should be passed through sigmoid or softmax before calculating loss or metric.\n",
    "\n",
    "Disabling reduction allows us to see the loss/metric scores for every class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "805731fe-8580-4378-9a7b-81bfde060208",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss5: [[[[0.08333331346511841]], [[0.0]], [[0.4999987483024597]]]] shape: (1, 3, 1, 1)\n",
      "metric5: [[0.9166666865348816, 1.0, 0.5]] shape: (1, 3)\n"
     ]
    }
   ],
   "source": [
    "loss5 = DiceLoss(reduction=\"none\")\n",
    "metric5 = DiceMetric(reduction=\"none\")\n",
    "\n",
    "print_tensor(\"loss5\", loss5(mpred2, mgrnd2))\n",
    "print_tensor(\"metric5\", metric5(mpred2, mgrnd2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cad0bc0-aeba-4ac4-9701-9a927e828675",
   "metadata": {},
   "source": [
    "From both outputs we can see that the background is almost perfectly predicted, class 1 is in fact perfect, and class 2 is off by half (ie. 1 pixel out of 2 is correct). These examples neglect the case of correctly segmenting a pixel but with the wrong category but this isn't really necessary here.\n",
    "\n",
    "If the ground truth isn't one-hot the loss function can be changed to do that conversion:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a1a5a8c5-7da5-4ca7-a196-3cb8428609d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss6: [[[[0.08333331346511841]], [[0.0]], [[0.4999987483024597]]]] shape: (1, 3, 1, 1)\n"
     ]
    }
   ],
   "source": [
    "loss6 = DiceLoss(reduction=\"none\", to_onehot_y=True)\n",
    "print_tensor(\"loss6\", loss6(mpred2, mgrnd[None]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
